#pragma once

#include <cmath>

#include <Math/Vector3.hpp>
#include <Math/Vector4.hpp>
#include <Math/Vector.hpp>
#include <Math/Matrix.hpp>
namespace zzz{
template<typename T> class Rotation;

// Quaternion stored as four components in v: the vector part (x, y, z) lives in
// v[0..2] and the scalar part w in v[3]. Products follow the Hamilton rule
//   (xyz1 + w1) * (xyz2 + w2) = Cross(xyz1, xyz2) + xyz1 * w2 + xyz2 * w1
//                               + w1 * w2 - Dot(xyz1, xyz2).
template<typename T>
class Quaternion : public Vector<4,T>
{
public:
  using Vector<4,T>::v;

  // Passes 0 to the Vector<4,T> constructor (presumably zero-fills all four
  // components -- confirm). This is NOT the identity rotation; call Identical()
  // to obtain (0, 0, 0, 1).
  Quaternion():Vector<4,T>(0){}

  // Rotation of _angle radians about _axis. The axis is passed through
  // SafeNormalize() first, so it does not need to be unit length.
  Quaternion(const VectorBase<3,T> &_axis, const T _angle):Vector<4,T>(_axis[0], _axis[1], _axis[2], 0)
  {
    XYZ().SafeNormalize();
    XYZ() *= sin(_angle/2.0);
    W()=cos(_angle/2.0);
  }

  // Wraps a raw 4-vector laid out as (x, y, z, w).
  explicit Quaternion(const VectorBase<4,T> &_v):Vector<4,T>(_v){}

  // Component-wise constructor; w is the scalar part.
  Quaternion(const T x, const T y, const T z, const T w):Vector<4,T>(x,y,z,w){}

  // Converts a 3x3 rotation matrix to a quaternion. The branch is chosen by the
  // trace / largest diagonal entry so the divisor s is never close to zero.
  // The result is normalized before returning.
  explicit Quaternion(const MatrixBase<3,3,T> &rot)
  {
    // Computed in T (not float) so Quaternion<double> keeps full precision.
    const T trace = rot(0,0) + rot(1,1) + rot(2,2);
    if (trace > 0)
    {
      // trace + 1 = 4w^2, so s = 1/(4w) and |w| > 1/2.
      const T s = T(0.5) / Sqrt<T>(trace + T(1));
      v[3] = T(0.25) / s;
      v[0] = (rot(2,1) - rot(1,2)) * s;
      v[1] = (rot(0,2) - rot(2,0)) * s;
      v[2] = (rot(1,0) - rot(0,1)) * s;
    }
    else if (rot(0,0) > rot(1,1) && rot(0,0) > rot(2,2))
    {
      // rot(0,0) is the largest diagonal entry: pivot on x (s = 4x).
      const T s = T(2) * Sqrt<T>(T(1) + rot(0,0) - rot(1,1) - rot(2,2));
      v[3] = (rot(2,1) - rot(1,2)) / s;
      v[0] = T(0.25) * s;
      v[1] = (rot(0,1) + rot(1,0)) / s;
      v[2] = (rot(0,2) + rot(2,0)) / s;
    }
    else if (rot(1,1) > rot(2,2))
    {
      // rot(1,1) is the largest diagonal entry: pivot on y (s = 4y).
      const T s = T(2) * Sqrt<T>(T(1) + rot(1,1) - rot(0,0) - rot(2,2));
      v[3] = (rot(0,2) - rot(2,0)) / s;
      v[0] = (rot(0,1) + rot(1,0)) / s;
      v[1] = T(0.25) * s;
      v[2] = (rot(1,2) + rot(2,1)) / s;
    }
    else
    {
      // rot(2,2) is the largest diagonal entry: pivot on z (s = 4z).
      const T s = T(2) * Sqrt<T>(T(1) + rot(2,2) - rot(0,0) - rot(1,1));
      v[3] = (rot(1,0) - rot(0,1)) / s;
      v[0] = (rot(0,2) + rot(2,0)) / s;
      v[1] = (rot(1,2) + rot(2,1)) / s;
      v[2] = T(0.25) * s;
    }

    // SafeNormalize is a member of the dependent base Vector<4,T>; it must be
    // reached through this-> for two-phase name lookup to find it.
    this->SafeNormalize();
  }

  // Copy-constructs from other's component array.
  Quaternion(const Quaternion &other):Vector<4,T>(other.v){}

  using Vector<4,T>::operator=;
  using Vector<4,T>::operator[];

  // Vector part (x, y, z) viewed in place as a Vector<3,T>.
  // NOTE(review): the reinterpret_cast type-puns v; this relies on Vector<3,T>
  // having the same layout as three consecutive T values -- confirm.
  Vector<3,T> &XYZ()
  {
    return *(reinterpret_cast<Vector<3,T>*>(&v));
  }
  const Vector<3,T> &XYZ() const
  {
    return *(reinterpret_cast<const Vector<3,T>*>(&v));
  }

  // Scalar part w (stored in v[3]).
  T &W()
  {
    return v[3];
  }
  const T &W() const
  {
    return v[3];
  }

  // Conjugate: negates the vector part, keeps w. Equals the inverse only for
  // unit quaternions.
  inline Quaternion operator~() const
  {
    return Quaternion(-v[0], -v[1], -v[2], v[3]);
  }

  // q * u, treating u as the pure quaternion (u, 0):
  // (xyz + w) * u = Cross(xyz,u) + w * u - Dot(xyz,u)
  inline friend Quaternion<T> operator*(const Quaternion &q, const VectorBase<3,T> &u)
  {
    Vector<3,T> xyz(Cross(q.XYZ(), u) + q.W() * u);
    T w = -FastDot(q.XYZ(), u);
    return Quaternion(xyz[0], xyz[1], xyz[2], w);
  }

  // u * q, treating u as the pure quaternion (u, 0):
  // u * (xyz + w) = Cross(u, xyz) + u * w - Dot(xyz,u)
  // The product has a scalar part, so the result is a full Quaternion.
  inline friend Quaternion<T> operator*(const VectorBase<3,T> &u, const Quaternion &q)
  {
    Vector<3,T> xyz(Cross(u, q.XYZ()) + u * q.W());
    T w = -FastDot(q.XYZ(), u);
    return Quaternion(xyz[0], xyz[1], xyz[2], w);
  }

  // Hamilton product:
  // (xyz1 + w1) * (xyz2 + w2) = Cross(xyz1, xyz2) + xyz1 * w2 + xyz2 * w1 + w1 * w2 - Dot(xyz1, xyz2)
  inline Quaternion operator*(const Quaternion &other) const
  {
    Vector<3,T> xyz(Cross(XYZ(), other.XYZ()) + XYZ()*other.W() + W()*other.XYZ());
    T w = W()*other.W() - FastDot(XYZ(), other.XYZ());
    return Quaternion(xyz[0], xyz[1], xyz[2], w);
  }

  // Sets this to the identity rotation (0, 0, 0, 1).
  void Identical()
  {v[0]=0; v[1]=0; v[2]=0; v[3]=1;}

  // Rotation angle in radians, in [0, 2*pi]; assumes a unit quaternion.
  // W() is clamped to [-1, 1] so rounding error cannot make acos() return NaN.
  inline T Angle() const
  {
    T w = W();
    if (w > T(1))
      w = T(1);
    else if (w < T(-1))
      w = T(-1);
    return 2*acos(w);
  }

  // Rotation axis: the normalized vector part. The vector part is zero for the
  // identity rotation, so the axis is not meaningful there.
  inline Vector<3,T> Axis() const {return XYZ().Normalized();}

  // Rotates u by this quaternion: q * u * ~q. The conjugate stands in for the
  // inverse, so this assumes *this is unit length.
  inline VectorBase<3,T> RotateVector(const VectorBase<3,T> &u) const
  {
    Quaternion<T> q = (*this)*u*(~(*this));
    return q.XYZ();
  }

  // Applies the inverse rotation: ~q * u * q (again assumes unit length).
  inline VectorBase<3,T> RotateBackVector(const VectorBase<3,T> &u) const
  {
    Quaternion<T> q((~(*this))*u*(*this));
    return q.XYZ();
  }

  // Sets this to a rotation of _angle radians about _axis. Mirrors the
  // axis-angle constructor, including SafeNormalize() on the axis, so both
  // entry points treat a degenerate axis the same way.
  inline void SetAxisAngle(const VectorBase<3,T> &_axis, const T _angle)
  {
    v[0] = _axis[0];
    v[1] = _axis[1];
    v[2] = _axis[2];
    XYZ().SafeNormalize();
    XYZ() *= sin(_angle/2.0);
    W()=cos(_angle/2.0);
  }

  // Spherical linear interpolation from q1 (t = 0) to q2 (t = 1).
  // t is Distance(q, q1) / Distance(q1, q2).
  // Takes the shorter arc: q2 is negated when Dot(q1, q2) < 0. Returns q2 when
  // the dot product is outside [-1, 1] or the angle is below EPSILON.
  static Quaternion<T> Slerp(const Quaternion<T> &q1, const Quaternion<T> &q2, const T t)
  {
    // If q1 and q2 are colinear, there is no rotation between q1 and q2.
    T cos_val = FastDot(q1, q2);
    if (cos_val > 1.0 || cos_val < -1.0)
      return q2;

    T factor = 1.0;
    if (cos_val < 0.0) {
      factor = -1.0;
      cos_val = -cos_val;
    }

    T angle = acos(cos_val);
    if (angle < EPSILON)
      return q2;

    T sin_val = sin(angle);
    T inv_sin_val = 1.0 / sin_val;
    T coeff1 = sin((1.0-t)*angle)*inv_sin_val;
    T coeff2 = sin(t*angle)*inv_sin_val;
    // Named 'blended' so it does not shadow the member array v.
    Vector<4,T> blended = static_cast<Vector<4,T> >(q1) * coeff1 +
                          static_cast<Vector<4,T> >(q2) * (factor * coeff2);
    return Quaternion<T>(blended);
  }

  // Tait-Bryan angles (radians) for the Z-Y-X sequence: roll about x, pitch
  // about y, yaw about z. Assumes a unit quaternion. Components are read as
  // x = v[0], y = v[1], z = v[2], w = v[3] (this class stores w last).
  T GetRoll() const
  {
    const T x = v[0], y = v[1], z = v[2], w = v[3];
    return atan2(2*(w*x + y*z), 1 - 2*(x*x + y*y));
  }
  T GetPitch() const
  {
    const T x = v[0], y = v[1], z = v[2], w = v[3];
    T s = 2*(w*y - z*x);
    // Near gimbal lock rounding can push |s| slightly past 1; clamp for asin.
    if (s > T(1))
      s = T(1);
    else if (s < T(-1))
      s = T(-1);
    return asin(s);
  }
  T GetYaw() const
  {
    const T x = v[0], y = v[1], z = v[2], w = v[3];
    return atan2(2*(w*z + x*y), 1 - 2*(y*y + z*z));
  }
};

// Convenience aliases over the built-in floating-point types.
typedef Quaternion<float> Quaternionf;
typedef Quaternion<double> Quaterniond;

// Aliases over the project's zfloat32 / zfloat64 typedefs
// (presumably 32- and 64-bit floats -- confirm).
typedef Quaternion<zfloat32> Quaternionf32;
typedef Quaternion<zfloat64> Quaternionf64;
}
